import pickle
import numpy as np
from FlagEmbedding import FlagModel
from tqdm import tqdm
import jsonlines
import pdb
import torch
import sys
import os
from argparse import ArgumentParser


# Command-line options for the retrieval script.
parser = ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    default="train",
    help="JSON-lines file of question records to augment with retrieved facts",
)
args = parser.parse_args()

# NOTE(review): pickle.load can execute arbitrary code while unpickling; this
# embedding file must be a trusted, locally generated artifact.
with open('./knowledge_base/facts_embedding.pkl', 'rb') as fh:
    vectors = pickle.load(fh)
print("load vector done ...")
k = 4  # number of nearest facts retrieved per query
model = FlagModel("bge-large-en-v1.5")
# One fact per line; line i is the text for embedding row i.
# NOTE(review): assumes facts.txt is UTF-8 encoded.
with open("./knowledge_base/facts.txt", "r", encoding="utf-8") as fh:
    facts = [line.strip() for line in fh]

# Going through a single float32 ndarray avoids torch's very slow path for
# building a tensor from a list of per-row numpy arrays (if that is what the
# pickle holds) and accepts float64 input as well.
vectors = torch.as_tensor(np.asarray(vectors, dtype=np.float32))
nb, d = vectors.shape  # number of fact embeddings, embedding dimension
# Retrieval maps embedding row i back to facts[i], so every row needs a fact.
if len(facts) < nb:
    raise ValueError(
        f"facts.txt has {len(facts)} lines but facts_embedding.pkl has {nb} "
        "embeddings; they must align row-for-row"
    )
print("building index done...")

# Fixed prompt phrasings stripped from questions before embedding.
_QUESTION_BOILERPLATE = ("What happened as a result?", "What was the cause of this?")
# Multiple-choice labels stripped from the joined options text.
_OPTION_LABELS = ("(A)", "(B)", "(C)", "(D)", "(E)")


def _query_text(record):
    """Return the text to embed for one input record.

    The question is read from key "Q" (falling back to "question"), with the
    boilerplate phrases removed. When the record also has "O" (falling back
    to "options"), the options are joined with spaces, their "(A)".."(E)"
    labels removed, and appended after the question.
    """
    question = record["Q"] if "Q" in record else record["question"]
    for phrase in _QUESTION_BOILERPLATE:
        question = question.replace(phrase, "")
    question = question.strip()
    if "O" not in record and "options" not in record:
        return question
    options = record["O"] if "O" in record else record["options"]
    # A bare string would otherwise be joined character by character.
    if isinstance(options, str):
        options = [options]
    options_text = " ".join(options)
    for label in _OPTION_LABELS:
        options_text = options_text.replace(label, "")
    return f"{question} {options_text.strip()}"


# Read all records up front; they are written back out, augmented, later.
with jsonlines.open(args.path, "r") as reader:
    data = list(reader)

# One embedding (as a plain list of floats) per record, in input order.
querys = []
for record in tqdm(data):
    querys.append(model.encode(_query_text(record)).tolist())

# reshape keeps the matrix 2-D even when there are no records, so the
# matmul below yields an empty (0, n_facts) result instead of failing.
querys = torch.FloatTensor(querys).reshape(-1, vectors.shape[1])

# Derive the output path without ever touching the input file. A plain
# str.replace(".jsonl", ...) leaves paths lacking that suffix unchanged
# (so the input would be overwritten) and also rewrites directory names.
if args.path.endswith(".jsonl"):
    out_path = args.path[: -len(".jsonl")] + "_retrieved.jsonl"
else:
    out_path = args.path + "_retrieved.jsonl"

# Dot-product score of every query against every fact embedding.
# NOTE(review): equals cosine similarity only if the embeddings are
# L2-normalized — assumed for this model, confirm.
results = querys @ vectors.t()
# Honor the configured k, but never ask for more facts than exist.
top_n = min(k, results.shape[1])
_, top_indices = results.topk(k=top_n, dim=1, largest=True)

with jsonlines.open(out_path, mode="w") as fo:
    for i, ranked in enumerate(tqdm(top_indices.tolist())):
        # dict.fromkeys de-duplicates while keeping best-score-first order;
        # a set would make the output order vary between runs.
        data[i]["K"] = list(dict.fromkeys(facts[j] for j in ranked))
        fo.write(data[i])


